
"""

AUTHOR: Marco Marchetti, Bruce A. Edgar Lab, Huntsman Cancer Institute, Salt Lake City, 84112, UT, USA

DESCRIPTION: This script is used to parse files generated by the "Source code 2" ImageJ macro. First, the lineage is reconstructed based
on relative cell distances between frames. For each cell in frame f, a new position in frame f+1 is assigned, based on the pool of annotated
cell positions in f+1. All possible permutations between f and f+1 coordinates are computed and the total distance between cell positions
in frame f and their newly assigned positions in f+1 is calculated. The permutation for which this distance is minimum is then kept. If lineages
are too complex (e.g. cells are moving and rearranging themselves), then the user can add two additional fields to the tab-separated input ".txt"
file ("CellID" and "MotherID" columns) indicating each cell's ID (for each cell position) and its mother cell's ID (can be stated once). The
lineage data is then parsed and formatted in an easy-to-read format, then exported to a ".tsv" file.

"""

def reconstructLineage_Manual(data, fields=None):
    """Reconstruct the cells lineage from the user-defined "CellID" and "MotherID" columns.

    Parameters
    ----------
    data : list of list
        One row per annotated cell position, with values ordered like ``fields``.
        A row's "MotherID" entry is the empty string when the mother is not stated.
    fields : list of str, optional
        Column names; must contain "CellID", "MotherID" and "T". When omitted, the
        module-level ``fields`` list defined by the main script is used, which keeps
        the original one-argument call working.

    Returns
    -------
    list of list
        One track per cell, ordered by increasing cell ID. A track holds the cell's
        rows (without the "CellID" and "MotherID" columns) sorted by time-point. It
        is preceded by "Daughter of #<i>" when a mother is stated, where <i> is the
        position of the mother's track in the returned list, and followed by
        "Divided" when the last row is not at the last time-point of the data.

    Raises
    ------
    ValueError
        If ``data`` is empty, or a "MotherID" does not match any "CellID".
    """

    if fields is None:
        # Backward compatibility: fall back on the module-level list set by the main script
        fields = globals()["fields"]

    id_index, mother_index, t1 = fields.index("CellID"), fields.index("MotherID"), fields.index("T")
    final_fields = [field for field in fields if field not in ("CellID", "MotherID")] # "CellID", "MotherID" will be discarded
    t2 = final_fields.index("T") # Index of "T" once the two ID columns are removed
    stop = int(max(d[t1] for d in data)) # Last time-point

    # Grouping rows by cell in a single pass (the ID columns are dropped from the stored rows)
    rows_by_cell, mother_by_cell = {}, {}
    for d in data:
        cell_id = d[id_index]
        kept = [value for column, value in enumerate(d) if column != id_index and column != mother_index]
        rows_by_cell.setdefault(cell_id, []).append(kept)
        # The mother can be stated once: the first non-empty entry is the one used
        if d[mother_index] != "" and cell_id not in mother_by_cell:
            mother_by_cell[cell_id] = int(d[mother_index])

    # Mapping each cell ID to the position of its track in the output, so that labels stay correct
    # even if the user started counting from a number different than 0 or skipped some IDs
    cell_ids = sorted(rows_by_cell)
    position = {cell_id: index for index, cell_id in enumerate(cell_ids)}

    lineage = []
    for cell_id in cell_ids:
        cell_data = sorted(rows_by_cell[cell_id], key=lambda row: row[t2])
        if cell_id in mother_by_cell:
            mother_id = mother_by_cell[cell_id]
            if mother_id not in position:
                raise ValueError("MotherID " + str(mother_id) + " does not match any CellID")
            cell_data = ["Daughter of #" + str(position[mother_id])] + cell_data
        # A track ending before the last time-point is marked as a division
        if cell_data[-1][t2] != stop:
            cell_data.append("Divided")
        lineage.append(cell_data)

    return lineage

def reconstructLineage_Auto(data, fields):
    """Reconstruct the cells lineage from cell positions only.

    Starting from the first time-point, every tracked cell is assigned a position in the
    next frame by testing all permutations and keeping the one with the smallest total
    Euclidean distance (first one in case of ties). Positions left unassigned are new
    cells: each is paired with its closest tracked cell (again minimising the total
    distance), and the pair is recorded as the two daughters of a division.

    Parameters
    ----------
    data : list of list
        One row of numbers per annotated cell position, ordered like ``fields``.
    fields : list of str
        Column names; must contain "X", "Y", "Z" and "T".

    Returns
    -------
    list of list
        Tracks of divided cells, in order of division, followed by the tracks still
        active at the last time-point. A track is a list of data rows, preceded by
        "Daughter of #<i>" (<i> being the mother's position in the returned list) for
        cells born from a division, and followed by "Divided" for cells that divided.

    Raises
    ------
    ValueError
        If ``data`` is empty, if a frame has fewer positions than tracked cells, or if
        a frame has more new cells than tracked cells to pair them with.
    """

    # Extracting coordinates
    x, y, z, t = (fields.index(axis) for axis in ("X", "Y", "Z", "T")) # Index of fields "X", "Y", "Z", "T"

    def distance(a, b):
        # Euclidean distance between two data rows
        return ((a[x] - b[x])**2 + (a[y] - b[y])**2 + (a[z] - b[z])**2)**0.5

    # First and last time-points, read from the "T" column wherever it is in the table
    time_points = [d[t] for d in data]
    start, stop = int(min(time_points)), int(max(time_points))

    lineage = []
    # List of cells to look for in the next frame. The algorithm starts from a single cell state.
    # NOTE(review): if several cells are annotated at the first time-point they all end up in this
    # one track; confirm whether multi-cell starts need support (plotLineage assumes a single root).
    active_cells = [[d for d in data if d[t] == start]]

    for frame in range(start + 1, stop + 1):

        # Defining old and new positions
        old = [track[-1] for track in active_cells]
        new = [d for d in data if d[t] == frame]
        if len(new) < len(old):
            raise ValueError("Time-point " + str(frame) + " has " + str(len(new)) + " annotated position(s) for "
                             + str(len(old)) + " tracked cell(s)")

        # Calculating all possible matches between old and new positions, then selecting the one that results in smaller total distance
        permutations1 = list(perm(range(len(new)), len(old)))
        distances = [sum(distance(o, new[target]) for o, target in zip(old, p)) for p in permutations1]
        best_match1 = permutations1[distances.index(min(distances))]
        for track, target in zip(active_cells, best_match1):
            track.append(new[target])

        # Finding new cells
        matched = set(best_match1)
        new_cells = [cell for i, cell in enumerate(new) if i not in matched]
        if not new_cells:
            continue
        possible_sisters = [i for i in range(len(new)) if i in matched]
        if len(new_cells) > len(possible_sisters):
            raise ValueError("Time-point " + str(frame) + " has " + str(len(new_cells)) + " new cell(s) but only "
                             + str(len(possible_sisters)) + " possible sister(s)")
        permutations2 = list(perm(possible_sisters, len(new_cells)))
        distances = [sum(distance(nc, new[target]) for nc, target in zip(new_cells, p)) for p in permutations2]
        best_match2 = permutations2[distances.index(min(distances))]

        for nc, sister_index in zip(new_cells, best_match2):
            # Position in "active_cells" of the track that was assigned the sister's position
            mother_index = best_match1.index(sister_index)
            # Removing last entry from mother cell and creating two new entries for the daughters
            mother = active_cells[mother_index][:-1] + ["Divided"]
            label = "Daughter of #" + str(len(lineage)) # The mother is about to be stored at this position
            sister_1 = [label, active_cells[mother_index][-1].copy()]
            sister_2 = [label, nc]
            lineage.append(mother)
            active_cells[mother_index] = sister_1
            active_cells.append(sister_2)

    lineage.extend(active_cells)

    return lineage

def plotLineage(lineage, fields, name):
    """Draw the lineage tree and save it as an in-scale and as a compact PNG image.

    Parameters
    ----------
    lineage : list of list
        Cell tracks. Each track is a list of data rows, optionally preceded by a
        "Daughter of #<i>" string (<i> being the position of the mother's track in
        ``lineage``) and optionally followed by the "Divided" string. The first track
        is taken as the root of the tree.
    fields : list of str
        Column names of the data rows; must contain "T".
    name : str
        Input file name. Its last four characters are replaced by
        "_InScaleLineage.png" and "_CompactLineage.png" to name the output images.

    Notes
    -----
    Relies on the module-level ``plt`` (matplotlib.pyplot) imported by the main script.
    Each image is drawn on its own figure, which is closed after saving. pyplot ignores
    "figsize" when asked for a figure number that is still open, so re-using a cleared
    figure would keep the size of the first plot ever drawn.
    """

    t = fields.index("T") # Index of field "T"
    label_length = len("Daughter of #") # The mother's position follows this prefix

    def first_row(cell):
        # First data row of a track (skipping the "Daughter of #" label, if any)
        return cell[0] if isinstance(cell[0], list) else cell[1]

    def last_row(cell):
        # Last data row of a track (skipping the "Divided" marker, if any)
        return cell[-1] if isinstance(cell[-1], list) else cell[-2]

    def find_sister(cell_index):
        # Position of the first other track carrying the same "Daughter of #" label
        label = lineage[cell_index][0]
        return [i for i, other in enumerate(lineage) if i != cell_index and other[0] == label][0]

    # Finding start and end time-points
    start = int(min(l[0][t] for l in lineage if isinstance(l[0], list)))
    stop = int(max(l[-1][t] for l in lineage if isinstance(l[-1], list)))

    # A lineage plot is composed of points marking the origin of cells and the final cells present, vertical lines that trace from the start to the
    # end of a cell's track, and horizontal lines marking mitotic events. Each cell has a unique X coordinate. Y coordinates depend on the time-point
    # an event takes place at. The plot is then built bottom-up, starting from the cells present at the last time-points.
    # Defining X coordinates. For each cell (after the first mother) its mother and sister are found, and the X coordinate of the sister pair is found
    # as mother_X +- 1. Then coordinates of cells coming before or after these cells are adjusted.
    x_coords = [0] + [None for _ in range(len(lineage) - 1)]
    for cell_index, cell in enumerate(lineage):

        if x_coords[cell_index] is not None: # Root cell, or cell already processed when its sister was
            continue

        # Finding cell's mother and sister
        mother_index = int(cell[0][label_length:])
        sister_index = find_sister(cell_index)

        # Shifting cells with lower x_coordinates than the cell, then inserting the latter in the list
        cell_x = x_coords[mother_index] - 1
        x_coords = [c - 1 if (c is not None) and (c <= cell_x) else c for c in x_coords]
        x_coords[cell_index] = cell_x

        # Shifting cells with larger x_coordinates than the cell's sister, then inserting the latter in the list
        sister_x = x_coords[mother_index] + 1
        x_coords = [c + 1 if (c is not None) and (c >= sister_x) else c for c in x_coords]
        x_coords[sister_index] = sister_x

    # Defining origin points Y coordinates (These points will be blue)
    y_origins = [first_row(l)[t] for l in lineage]

    # Defining the coordinates of points present at the last time point (These points will be red).
    # Tracks ending with the "Divided" marker are skipped explicitly rather than by indexing the marker string.
    x_final_cells = [x for x, l in zip(x_coords, lineage) if isinstance(l[-1], list) and l[-1][t] == stop]
    y_final_cells = [stop for _ in x_final_cells]

    # Defining vertical lines marking a cell's existence. A line runs one time-point past the last
    # row (i.e. down to where the daughters appear) but never beyond the last time-point.
    verticals = []
    for cell, x in zip(lineage, x_coords):
        y_start = first_row(cell)[t]
        y_stop = last_row(cell)[t] + 1
        if y_stop > stop:
            y_stop = stop
        verticals.append([[x, x], [y_start, y_stop]])

    # Defining horizontal lines marking a daughter pair
    horizontals, processed_sisters = [], set()
    for cell_index in range(1, len(lineage)):
        if cell_index in processed_sisters: # Horizontal line already defined when processing the sister cell
            continue
        pair_y = first_row(lineage[cell_index])[t]
        sister_index = find_sister(cell_index)
        processed_sisters.add(sister_index)
        horizontals.append([[x_coords[cell_index], x_coords[sister_index]], [pair_y, pair_y]])

    # Adding text next to starting, ending, and mitotic time-points
    text_x = [x_coords[0] - 1, min(x_final_cells) - 1]
    text_y = [start + 0.25, stop + 0.25]
    text = [str(start - 1) + "h", str(stop - 1) + "h"]
    for h in horizontals:
        text_x.append(min(h[0]) - 1)
        text_y.append(h[1][0] + 0.25)
        text.append(str(int(h[1][0]) - 1) + "h")

    # Plotting the in scale lineage
    figure = plt.figure(figsize = (len(x_coords), 3 * len(x_coords)))
    plt.axis('off')
    plt.axis([min(x_coords) - 0.5, max(x_coords) + 0.5, stop + 1, start - 1]) # Inverting the y axis
    for v in verticals:
        plt.plot(v[0], v[1], "black", linewidth = 5)
    for h in horizontals:
        plt.plot(h[0], h[1], "black", linewidth = 5)
    plt.plot(x_coords, y_origins, "bo", markersize = 20, markeredgecolor = "black", markeredgewidth = 3, linewidth = 0)
    plt.plot(x_final_cells, y_final_cells, "ro", markersize = 20, markeredgecolor = "black", markeredgewidth = 3, linewidth = 0)
    for label_x, label_y, label in zip(text_x, text_y, text):
        plt.text(label_x, label_y, label, fontsize = 20, fontweight = "bold")
    plt.savefig(name[:-4] + "_InScaleLineage.png", dpi = 300)
    plt.close(figure)

    # As an in scale lineage plot may be quite long, a more compact one can be desirable. The dimension along the Y axis will therefore be
    # modified so that each major step (i.e. where text is) will be equally distant from the next/previous. Levels shared by several labels
    # (e.g. two divisions at the same time-point) are counted once, otherwise the steps would not be evenly spaced.
    levels = sorted(set(text_y))
    conversion = {level - 0.25 : rank for rank, level in enumerate(levels)}
    figure = plt.figure(figsize = (len(x_coords), 2 * len(levels)))
    plt.axis('off')
    plt.axis([min(x_coords) - 0.5, max(x_coords) + 0.5, conversion[stop] + 1, conversion[start] - 1]) # Inverting the y axis
    for v in verticals:
        plt.plot(v[0], [conversion[y] for y in v[1]], "black", linewidth = 5)
    for h in horizontals:
        plt.plot(h[0], [conversion[y] for y in h[1]], "black", linewidth = 5)
    compact_origins = [conversion[y] for y in y_origins]
    plt.plot(x_coords, compact_origins, "bo", markersize = 20, markeredgecolor = "black", markeredgewidth = 3, linewidth = 0)
    compact_final_cells = [conversion[y] for y in y_final_cells]
    plt.plot(x_final_cells, compact_final_cells, "ro", markersize = 20, markeredgecolor = "black", markeredgewidth = 3, linewidth = 0)
    for label_x, label_y, label in zip(text_x, text_y, text):
        plt.text(label_x, conversion[label_y - 0.25], label, fontsize = 20, fontweight = "bold", verticalalignment = "center")
    plt.savefig(name[:-4] + "_CompactLineage.png", dpi = 300)
    plt.close(figure)

def exportLineageData(name, fields, lineage):
    """Export parsed and ordered lineage data to a tab-separated file.

    The output starts with a "TimePoint" header row listing every time-point from the
    first to the last. Then, for each field after the first one, a row holding the
    field name is followed by one row per cell ("Cell-<position in lineage>"), whose
    values are aligned to the time-point columns. The "Daughter of #<i>" label of a
    cell sits in the column just before its first value, and the "Divided" marker in
    the column just after its last value.

    Parameters
    ----------
    name : str
        Input file name; its last four characters are replaced by "_Data.tsv" to name
        the output file.
    fields : list of str
        Column names of the data rows; must contain "T".
    lineage : list of list
        Cell tracks: data rows, optionally preceded by a "Daughter of #<i>" string and
        optionally followed by the "Divided" string.
    """

    t = fields.index("T") # Index of field "T"
    output_name = name[:-4] + "_Data.tsv"
    start = int(min(cell[0][t] for cell in lineage if isinstance(cell[0], list)))
    stop = int(max(cell[-1][t] for cell in lineage if isinstance(cell[-1], list)))

    output_text = ["\t".join(["TimePoint"] + [str(time_point) for time_point in range(start, stop + 1)])]
    # NOTE(review): the first field is skipped, presumably because it is expected to be "T" - confirm
    for f in range(1, len(fields)):
        output_text.append(fields[f])
        for c, cell in enumerate(lineage):
            # First column used by the cell: its first time-point, or one column earlier to make room for the "Daughter of #" label
            beginning = int(cell[0][t]) if isinstance(cell[0], list) else int(cell[1][t] - 1)
            spacer = "\t" * (beginning - start)
            values = [str(entry[f]) if isinstance(entry, list) else entry for entry in cell]
            output_text.append("Cell-" + str(c) + "\t" + spacer + "\t".join(values))

    # The context manager closes the file even if writing fails
    with open(output_name, "w") as output:
        output.write("\n".join(output_text))

###MAIN

print("Loading dependencies...")

# Only import failures are reported as missing dependencies; any other error keeps its traceback
try:
    from itertools import permutations as perm
    from matplotlib import pyplot as plt
    from os import listdir
except ImportError:
    print("One or more dependencies are not installed.\nAlso, make sure your terminal has been activated.")
    raise SystemExit

# Import data from each txt file of the working directory, reconstruct the lineage, then plot and export it
files = [f for f in listdir() if f[-4:] == ".txt"]
fields_of_interest = ["Area", "MeanRFP", "NormMeanRFP", "MeanGFP", "NormMeanGFP"]
for f in files:
    print("\nProcessing '" + f + "'")

    # Reading the file once; the first line holds the field names, the others one cell position each
    with open(f) as table:
        lines = table.read().split("\n")
    # "fields" must stay a module-level name: reconstructLineage_Manual reads it when called with "data" only
    fields = lines[0].split("\t")
    # Values are converted to numbers; empty entries (e.g. unstated "MotherID") are kept as empty strings
    data = [[float(info) if info != "" else "" for info in entry.split("\t")] for entry in lines[1:] if len(entry)]

    if "CellID" in fields:
        # Lineage manually defined by the user
        cell_lineage = reconstructLineage_Manual(data)
        # The returned tracks no longer contain the two ID columns
        fields.remove("CellID")
        fields.remove("MotherID")
    else:
        cell_lineage = reconstructLineage_Auto(data, fields)
    plotLineage(cell_lineage, fields, f)
    exportLineageData(f, fields, cell_lineage)